#include "megdnn/oprs/general.h"
#include "src/common/utils.h"

#include <limits>

using namespace megdnn;

/*!
 * \brief validate the layouts passed to a param-pack concat/split kernel
 *
 * \param concated layout of the packed buffer
 * \param offsets layout of the offsets tensor; its dtype must be Int32
 * \param parts layout of the parts tensor
 *
 * All three layouts must be one-dimensional with unit stride, i.e. flat and
 * densely packed. Violations are reported through megdnn_assert.
 */
void ParamPackConcatSplitBase::check_exec(
        const TensorLayout& concated, const TensorLayout& offsets,
        const TensorLayout& parts) {
    // offsets must be int32, the same element type gen_offsets produces
    megdnn_assert(
            offsets.dtype == dtype::Int32{}, "bad dtype: %s", offsets.dtype.name());

    // check every rank first; strides are only inspected once all three are
    // known to be 1-D
    const bool every_layout_is_1d =
            concated.ndim == 1 && offsets.ndim == 1 && parts.ndim == 1;
    const bool every_layout_is_dense = every_layout_is_1d &&
                                       concated.stride[0] == 1 &&
                                       offsets.stride[0] == 1 && parts.stride[0] == 1;
    megdnn_assert(
            every_layout_is_dense, "bad layout: concated=%s offsets=%s parts=%s",
            concated.to_string().c_str(), offsets.to_string().c_str(),
            parts.to_string().c_str());
}

std::vector<dt_int32> ParamPackConcatSplitBase::gen_offsets(
        const TensorShapeArray& shapes, size_t alignment, size_t dtype_size) {
    megdnn_assert(
            alignment && (alignment & (alignment - 1)) == 0,
            "alignment must be power of 2: %zu", alignment);
    if (alignment < dtype_size)
        alignment = dtype_size;

    megdnn_assert(
            alignment % dtype_size == 0,
            "alignment must be multiple of dtype size: %zu vs %zu", alignment,
            dtype_size);
    alignment /= dtype_size;

    auto get_aligned = [alignment](size_t v) {
        auto mod = v & (alignment - 1);
        return v + ((alignment - mod) & (alignment - 1));
    };

    std::vector<dt_int32> offsets(shapes.size() << 1);
    size_t offset = 0;
    for (size_t i = 0; i < shapes.size(); i++) {
        offset = get_aligned(offset);
        offsets[i << 1] = offset;
        offset += shapes[i].total_nr_elems();
        offsets[(i << 1) + 1] = offset;
    }
    return offsets;
}

// vim: syntax=cpp.doxygen
